package cn.bugstack.openai.executor.model.chatgpt;

import cn.bugstack.openai.exception.OpenAiSdkException;
import cn.bugstack.openai.executor.Executor;
import cn.bugstack.openai.executor.model.chatgpt.config.ChatGPTConfig;
import cn.bugstack.openai.executor.model.chatgpt.valobj.ChatGPTCompletionRequest;
import cn.bugstack.openai.executor.model.chatgpt.valobj.ChatGPTImageRequest;
import cn.bugstack.openai.executor.parameter.*;
import cn.bugstack.openai.executor.result.ResultHandler;
import cn.bugstack.openai.session.Configuration;
import com.chatplus.application.common.util.PlusJsonUtils;
import okhttp3.*;
import okhttp3.sse.EventSource;
import okhttp3.sse.EventSourceListener;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

/**
 * ChatGPT 模型执行器 https://openai.apifox.cn/doc-3222729
 *
 * @author 小傅哥，微信：fustack
 */
public class ChatGPTModelExecutor implements Executor, ParameterHandler<ChatGPTCompletionRequest>, ResultHandler {
    /**
     * 工厂事件
     */
    private final EventSource.Factory factory;
    /**
     * http 客户端
     */
    private final OkHttpClient okHttpClient;

    public ChatGPTModelExecutor(Configuration configuration) {
        this.factory = configuration.createRequestFactory();
        this.okHttpClient = configuration.getOkHttpClient();
    }

    @Override
    public EventSource completions(CompletionRequest completionRequest, EventSourceListener eventSourceListener) throws Exception {
        return completions(null, null, completionRequest, eventSourceListener);
    }

    @Override
    public EventSource completions(String apiHostByUser, String apiKeyByUser, CompletionRequest completionRequest, EventSourceListener eventSourceListener) throws Exception {
        // 1. 核心参数校验；不对用户的传参做更改，只返回错误信息。
        if (!completionRequest.isStream()) {
            throw new OpenAiSdkException("illegal parameter stream is false!");
        }
        // 3. 转换参数信息
        ChatGPTCompletionRequest chatGPTCompletionRequest = getParameterObject(completionRequest);
        // 4. 构建请求信息
        Request request = new Request.Builder()
                .url(ChatGPTConfig.API_HOST.concat(ChatGPTConfig.V1_CHAT_COMPLETIONS))
                .tag(completionRequest.getTag())
                .post(RequestBody.create(PlusJsonUtils.toJsonString(chatGPTCompletionRequest), MediaType.parse(Configuration.APPLICATION_JSON)))
                .build();
        // 5. 返回事件结果
        return factory.newEventSource(request, eventSourceListener);
    }

    @Override
    public ImageResponse genImages(ImageRequest imageRequest) throws Exception {
        return genImages(null, null, imageRequest);
    }

    @Override
    public ImageResponse genImages(String apiHostByUser, String apiKeyByUser, ImageRequest imageRequest) throws Exception {
        // 1. 统一转换参数
        ChatGPTImageRequest chatGPTImageRequest = ChatGPTImageRequest.builder()
                .n(imageRequest.getN())
                .size(imageRequest.getSize())
                .prompt(imageRequest.getPrompt())
                .model(imageRequest.getModel())
                .style(imageRequest.getStyle())
                .quality(imageRequest.getQuality())
                .responseFormat(imageRequest.getResponseFormat())
                .build();
        // 构建请求信息
        Request request = new Request.Builder()
                .url(ChatGPTConfig.API_HOST.concat(ChatGPTConfig.V1_IMAGES_COMPLETIONS))
                .header("Content-Type", Configuration.APPLICATION_JSON)
                // 封装请求参数信息，如果使用了 Fastjson 也可以替换 ObjectMapper 转换对象
                .post(RequestBody.create(PlusJsonUtils.toJsonString(chatGPTImageRequest), MediaType.parse(Configuration.APPLICATION_JSON)))
                .build();

        Call call = okHttpClient.newCall(request);
        try (Response execute = call.execute()) {
            ResponseBody body = execute.body();
            if (execute.isSuccessful() && body != null) {
                String responseBody = body.string();
                return PlusJsonUtils.parseObject(responseBody, ImageResponse.class);
            } else {
                throw new IOException("Failed to get image response");
            }
        }
    }

    @Override
    public EventSource pictureUnderstanding(PictureRequest pictureRequest, EventSourceListener eventSourceListener) throws Exception {
        return null;
    }

    @Override
    public EventSource pictureUnderstanding(String apiHostByUser, String apiKeyByUser, PictureRequest pictureRequest, EventSourceListener eventSourceListener) throws Exception {
        return null;
    }

    @Override
    public ChatGPTCompletionRequest getParameterObject(CompletionRequest completionRequest) {
        // 转换参数
        List<cn.bugstack.openai.executor.model.chatgpt.valobj.Message> chatGPTMessages = new ArrayList<>();
        List<Message> messages = completionRequest.getMessages();
        for (Message message : messages) {
            cn.bugstack.openai.executor.model.chatgpt.valobj.Message messageVO = new cn.bugstack.openai.executor.model.chatgpt.valobj.Message();
            messageVO.setContent(message.getContent());
            messageVO.setName(message.getName());
            messageVO.setRole(message.getRole());
            chatGPTMessages.add(messageVO);
        }
        // 封装参数
        ChatGPTCompletionRequest chatGPTCompletionRequest = new ChatGPTCompletionRequest();
        chatGPTCompletionRequest.setModel(completionRequest.getModel());
        chatGPTCompletionRequest.setTemperature(completionRequest.getTemperature());
        chatGPTCompletionRequest.setTopP(completionRequest.getTopP());
        chatGPTCompletionRequest.setStream(completionRequest.isStream());
        chatGPTCompletionRequest.setMessages(chatGPTMessages);
        chatGPTCompletionRequest.setFunctions(completionRequest.getFunctions());
        chatGPTCompletionRequest.setFunctionCall(completionRequest.getFunctionCall());
        return chatGPTCompletionRequest;
    }

    @Override
    public boolean supportFunction(String model) {
        return false;
    }

    @Override
    public EventSourceListener eventSourceListener(EventSourceListener eventSourceListener) {
        return eventSourceListener;
    }

}
